{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5c6594-e5e8-4966-8cb8-a3e6a9ed7d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fce0c950-2e03-4be1-95d4-a02409d8dba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7c2ba5-19ee-421e-9252-7224b03f5201",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import inspect\n",
    "import random\n",
    "import warnings\n",
    "from contextlib import contextmanager\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import fsspec\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "\n",
    "from neuralforecast.tsdataset import (\n",
    "    TimeSeriesDataModule,\n",
    "    TimeSeriesDataset,\n",
    "    _DistributedTimeSeriesDataModule,\n",
    ")\n",
    "from neuralforecast.losses.pytorch import IQLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d4c4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
     "@dataclass\n",
     "class DistributedConfig:\n",
     "    \"\"\"Configuration for distributed (multi-node / multi-GPU) training.\n",
     "\n",
     "    Consumed by `BaseModel._fit_distributed`, which maps `num_nodes` and\n",
     "    `devices` onto Spark `TorchDistributor` tasks.\"\"\"\n",
     "    # Path to the partitioned dataset. Not read in this module -- presumably\n",
     "    # used by the data-loading side; TODO confirm against callers.\n",
     "    partitions_path: str\n",
     "    # Number of worker nodes participating in training.\n",
     "    num_nodes: int\n",
     "    # Number of devices (GPUs) per node.\n",
     "    devices: int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5197e340-11f1-4c8c-96d1-ed396ac2b710",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "@contextmanager\n",
    "def _disable_torch_init():\n",
    "    \"\"\"Context manager used to disable pytorch's weight initialization.\n",
    "\n",
    "    This is especially useful when loading saved models, since when initializing\n",
    "    a model the weights are also initialized following some method\n",
    "    (e.g. kaiming uniform), and that time is wasted since we'll override them with\n",
    "    the saved weights.\"\"\"\n",
    "    def noop(*args, **kwargs):\n",
    "        return\n",
    "        \n",
    "    kaiming_uniform = nn.init.kaiming_uniform_\n",
    "    kaiming_normal = nn.init.kaiming_normal_\n",
    "    xavier_uniform = nn.init.xavier_uniform_\n",
    "    xavier_normal = nn.init.xavier_normal_\n",
    "    \n",
    "    nn.init.kaiming_uniform_ = noop\n",
    "    nn.init.kaiming_normal_ = noop\n",
    "    nn.init.xavier_uniform_ = noop\n",
    "    nn.init.xavier_normal_ = noop\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        nn.init.kaiming_uniform_ = kaiming_uniform\n",
    "        nn.init.kaiming_normal_ = kaiming_normal\n",
    "        nn.init.xavier_uniform_ = xavier_uniform\n",
    "        nn.init.xavier_normal_ = xavier_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60c40a64-8381-46a2-8cbb-70ec70ed7914",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
     "class BaseModel(pl.LightningModule):\n",
     "    \"\"\"Common base class for neuralforecast models.\n",
     "\n",
     "    Handles loss/optimizer setup, exogenous-variable bookkeeping, trainer\n",
     "    argument preparation, local and distributed fitting, and (de)serialization.\n",
     "\n",
     "    NOTE(review): several attributes read below (`learning_rate`, `max_steps`,\n",
     "    `val_check_steps`, `num_workers_loader`, `drop_last_loader`,\n",
     "    `lr_decay_steps`, `alias`, `validation_step_outputs`) are not assigned in\n",
     "    this class -- presumably subclasses define them; confirm against the\n",
     "    concrete model implementations.\n",
     "    \"\"\"\n",
     "\n",
     "    # Capability flags: subclasses set these to False to reject the\n",
     "    # corresponding kind of exogenous variables (enforced in __init__).\n",
     "    EXOGENOUS_FUTR = True\n",
     "    EXOGENOUS_HIST = True\n",
     "    EXOGENOUS_STAT = True\n",
     "\n",
     "    def __init__(\n",
     "        self,\n",
     "        random_seed,\n",
     "        loss,\n",
     "        valid_loss,\n",
     "        optimizer,\n",
     "        optimizer_kwargs,\n",
     "        lr_scheduler,\n",
     "        lr_scheduler_kwargs,\n",
     "        futr_exog_list,\n",
     "        hist_exog_list,\n",
     "        stat_exog_list,\n",
     "        max_steps,\n",
     "        early_stop_patience_steps,\n",
     "        **trainer_kwargs,\n",
     "    ):\n",
     "        \"\"\"Store the training configuration and prepare `trainer_kwargs`.\n",
     "\n",
     "        Extra keyword arguments are forwarded to `pl.Trainer` (with\n",
     "        `max_steps`, early-stopping callbacks and accelerator defaults\n",
     "        injected below). Raises TypeError for invalid optimizer/scheduler\n",
     "        classes and Exception for unsupported exogenous variables or\n",
     "        inconsistent IQLoss configuration.\n",
     "        \"\"\"\n",
     "        super().__init__()\n",
     "        with warnings.catch_warnings(record=False):\n",
     "            warnings.filterwarnings('ignore')\n",
     "            # the following line issues a warning about the loss attribute being saved\n",
     "            # but we do want to save it\n",
     "            self.save_hyperparameters() # Allows instantiation from a checkpoint from class\n",
     "        self.random_seed = random_seed\n",
     "        pl.seed_everything(self.random_seed, workers=True)\n",
     "\n",
     "        # Loss\n",
     "        self.loss = loss\n",
     "        if valid_loss is None:\n",
     "            self.valid_loss = loss\n",
     "        else:\n",
     "            self.valid_loss = valid_loss\n",
     "        self.train_trajectories = []\n",
     "        self.valid_trajectories = []\n",
     "\n",
     "        # Optimization\n",
     "        if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
     "            raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
     "        self.optimizer = optimizer\n",
     "        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
     "\n",
     "        # lr scheduler\n",
     "        if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
     "            raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
     "        self.lr_scheduler = lr_scheduler\n",
     "        self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
     "\n",
     "\n",
     "        # Variables\n",
     "        self.futr_exog_list = list(futr_exog_list) if futr_exog_list is not None else []\n",
     "        self.hist_exog_list = list(hist_exog_list) if hist_exog_list is not None else []\n",
     "        self.stat_exog_list = list(stat_exog_list) if stat_exog_list is not None else []\n",
     "\n",
     "        # Set data sizes\n",
     "        self.futr_exog_size = len(self.futr_exog_list)\n",
     "        self.hist_exog_size = len(self.hist_exog_list)\n",
     "        self.stat_exog_size = len(self.stat_exog_list)   \n",
     "\n",
     "        # Check if model supports exogenous, otherwise raise Exception\n",
     "        if not self.EXOGENOUS_FUTR and self.futr_exog_size > 0:\n",
     "            raise Exception(f'{type(self).__name__} does not support future exogenous variables.')\n",
     "        if not self.EXOGENOUS_HIST and self.hist_exog_size > 0:\n",
     "            raise Exception(f'{type(self).__name__} does not support historical exogenous variables.')\n",
     "        if not self.EXOGENOUS_STAT and self.stat_exog_size > 0:\n",
     "            raise Exception(f'{type(self).__name__} does not support static exogenous variables.')\n",
     "\n",
     "        # Implicit Quantile Loss\n",
     "        if isinstance(self.loss, IQLoss):\n",
     "            if not isinstance(self.valid_loss, IQLoss):\n",
     "                raise Exception('Please set valid_loss to IQLoss() when training with IQLoss')\n",
     "        if isinstance(self.valid_loss, IQLoss) and not isinstance(self.loss, IQLoss):\n",
     "            raise Exception('Please set loss to IQLoss() when validating with IQLoss')        \n",
     "\n",
     "        ## Trainer arguments ##\n",
     "        # Max steps, validation steps and check_val_every_n_epoch\n",
     "        trainer_kwargs = {**trainer_kwargs, 'max_steps': max_steps}\n",
     "\n",
     "        if 'max_epochs' in trainer_kwargs.keys():\n",
     "            raise Exception('max_epochs is deprecated, use max_steps instead.')\n",
     "\n",
     "        # Callbacks\n",
     "        if early_stop_patience_steps > 0:\n",
     "            if 'callbacks' not in trainer_kwargs:\n",
     "                trainer_kwargs['callbacks'] = []\n",
     "            trainer_kwargs['callbacks'].append(\n",
     "                EarlyStopping(\n",
     "                    monitor='ptl/val_loss', patience=early_stop_patience_steps\n",
     "                )\n",
     "            )\n",
     "\n",
     "        # Add GPU accelerator if available\n",
     "        if trainer_kwargs.get('accelerator', None) is None:\n",
     "            if torch.cuda.is_available():\n",
     "                trainer_kwargs['accelerator'] = \"gpu\"\n",
     "        if trainer_kwargs.get('devices', None) is None:\n",
     "            if torch.cuda.is_available():\n",
     "                trainer_kwargs['devices'] = -1\n",
     "\n",
     "        # Avoid saturating local memory, disabled fit model checkpoints\n",
     "        if trainer_kwargs.get('enable_checkpointing', None) is None:\n",
     "            trainer_kwargs['enable_checkpointing'] = False\n",
     "\n",
     "        self.trainer_kwargs = trainer_kwargs\n",
     "\n",
     "    def __repr__(self):\n",
     "        # NOTE(review): `alias` is presumably assigned by subclasses; this\n",
     "        # raises AttributeError if it never was -- TODO confirm.\n",
     "        return type(self).__name__ if self.alias is None else self.alias\n",
     "\n",
     "    def _check_exog(self, dataset):\n",
     "        \"\"\"Raise if any declared exogenous column is missing from `dataset`.\"\"\"\n",
     "        temporal_cols = set(dataset.temporal_cols.tolist())\n",
     "        static_cols = set(dataset.static_cols.tolist() if dataset.static_cols is not None else [])\n",
     "\n",
     "        missing_hist = set(self.hist_exog_list) - temporal_cols\n",
     "        missing_futr = set(self.futr_exog_list) - temporal_cols\n",
     "        missing_stat = set(self.stat_exog_list) - static_cols\n",
     "        if missing_hist:\n",
     "            raise Exception(f'{missing_hist} historical exogenous variables not found in input dataset')\n",
     "        if missing_futr:\n",
     "            raise Exception(f'{missing_futr} future exogenous variables not found in input dataset')\n",
     "        if missing_stat:\n",
     "            raise Exception(f'{missing_stat} static exogenous variables not found in input dataset')\n",
     "\n",
     "    def _restart_seed(self, random_seed):\n",
     "        \"\"\"Re-seed torch, falling back to the seed given at construction.\"\"\"\n",
     "        if random_seed is None:\n",
     "            random_seed = self.random_seed\n",
     "        torch.manual_seed(random_seed)\n",
     "\n",
     "    def _get_temporal_exogenous_cols(self, temporal_cols):\n",
     "        \"\"\"Columns of `temporal_cols` declared as historic or future exogenous.\"\"\"\n",
     "        return list(\n",
     "            set(temporal_cols.tolist()) & set(self.hist_exog_list + self.futr_exog_list)\n",
     "        )\n",
     "    \n",
     "    def _set_quantile_for_iqloss(self, **data_module_kwargs):\n",
     "        \"\"\"Pop `quantile` from the kwargs and propagate it to the IQLoss.\n",
     "\n",
     "        Defaults the quantile to 0.5 when training with IQLoss and no quantile\n",
     "        is given; raises if `quantile` is passed without an IQLoss loss.\n",
     "        Returns the remaining kwargs.\"\"\"\n",
     "        if \"quantile\" in data_module_kwargs:\n",
     "            if not isinstance(self.loss, IQLoss):\n",
     "                raise Exception(\n",
     "                    \"Please train with loss=IQLoss() to make use of the quantile argument.\"\n",
     "                )\n",
     "            else:\n",
     "                self.quantile = data_module_kwargs[\"quantile\"]\n",
     "                data_module_kwargs.pop(\"quantile\")\n",
     "                self.loss.update_quantile(q=self.quantile)\n",
     "        elif isinstance(self.loss, IQLoss):\n",
     "            self.quantile = 0.5\n",
     "            self.loss.update_quantile(q=self.quantile)\n",
     "\n",
     "        return data_module_kwargs\n",
     "\n",
     "    def _fit_distributed(\n",
     "        self,\n",
     "        distributed_config,\n",
     "        datamodule,\n",
     "        val_size,\n",
     "        test_size,\n",
     "    ):\n",
     "        \"\"\"Fit on Spark via `TorchDistributor` and return the trained model.\"\"\"\n",
     "        assert distributed_config is not None\n",
     "        from pyspark.ml.torch.distributor import TorchDistributor\n",
     "\n",
     "        def train_fn(\n",
     "            model_cls,\n",
     "            model_params,\n",
     "            datamodule,\n",
     "            trainer_kwargs,\n",
     "            num_tasks,\n",
     "            num_proc_per_task,\n",
     "            val_size,\n",
     "            test_size,\n",
     "        ):\n",
     "            # Runs inside each Spark task; must be self-contained (re-imports).\n",
     "            import pytorch_lightning as pl\n",
     "\n",
     "            # we instantiate here to avoid pickling large tensors (weights)\n",
     "            model = model_cls(**model_params)\n",
     "            model.val_size = val_size\n",
     "            model.test_size = test_size\n",
     "            # devices/num_nodes are set explicitly below; drop user overrides.\n",
     "            for arg in ('devices', 'num_nodes'):\n",
     "                trainer_kwargs.pop(arg, None)\n",
     "            trainer = pl.Trainer(\n",
     "                strategy=\"ddp\",\n",
     "                use_distributed_sampler=False,  # to ensure our dataloaders are used as-is\n",
     "                num_nodes=num_tasks,\n",
     "                devices=num_proc_per_task,\n",
     "                **trainer_kwargs,\n",
     "            )\n",
     "            trainer.fit(model=model, datamodule=datamodule)\n",
     "            model.metrics = trainer.callback_metrics\n",
     "            # drop the trainer reference so the model can be pickled back\n",
     "            model.__dict__.pop('_trainer', None)\n",
     "            return model\n",
     "\n",
     "        def is_gpu_accelerator(accelerator):\n",
     "            from pytorch_lightning.accelerators.cuda import CUDAAccelerator\n",
     "\n",
     "            return (\n",
     "                accelerator == \"gpu\"\n",
     "                or isinstance(accelerator, CUDAAccelerator)\n",
     "                or (accelerator == \"auto\" and CUDAAccelerator.is_available())\n",
     "            )\n",
     "\n",
     "        local_mode = distributed_config.num_nodes == 1\n",
     "        if local_mode:\n",
     "            num_tasks = 1\n",
     "            num_proc_per_task = distributed_config.devices\n",
     "        else:\n",
     "            num_tasks = distributed_config.num_nodes * distributed_config.devices\n",
     "            num_proc_per_task = 1  # number of GPUs per task\n",
     "        num_proc = num_tasks * num_proc_per_task\n",
     "        # NOTE(review): raises KeyError when no accelerator was set (CPU-only\n",
     "        # host without a user override in trainer_kwargs) -- confirm intended.\n",
     "        use_gpu = is_gpu_accelerator(self.trainer_kwargs[\"accelerator\"])\n",
     "        model = TorchDistributor(\n",
     "            num_processes=num_proc,\n",
     "            local_mode=local_mode,\n",
     "            use_gpu=use_gpu,\n",
     "        ).run(\n",
     "            train_fn,\n",
     "            model_cls=type(self),\n",
     "            model_params=self.hparams,\n",
     "            datamodule=datamodule,\n",
     "            trainer_kwargs=self.trainer_kwargs,\n",
     "            num_tasks=num_tasks,\n",
     "            num_proc_per_task=num_proc_per_task,\n",
     "            val_size=val_size,\n",
     "            test_size=test_size,\n",
     "        )\n",
     "        return model\n",
     "\n",
     "    def _fit(\n",
     "        self,\n",
     "        dataset,\n",
     "        batch_size,\n",
     "        valid_batch_size=1024,\n",
     "        val_size=0,\n",
     "        test_size=0,\n",
     "        random_seed=None,\n",
     "        shuffle_train=True,\n",
     "        distributed_config=None,\n",
     "    ):\n",
     "        \"\"\"Train on a local or distributed dataset and return the fitted model.\n",
     "\n",
     "        Local `TimeSeriesDataset`s are trained in-process; any other dataset\n",
     "        is assumed distributed and delegated to `_fit_distributed`.\"\"\"\n",
     "        self._check_exog(dataset)\n",
     "        self._restart_seed(random_seed)\n",
     "\n",
     "        self.val_size = val_size\n",
     "        self.test_size = test_size\n",
     "        is_local = isinstance(dataset, TimeSeriesDataset)\n",
     "        if is_local:\n",
     "            datamodule_constructor = TimeSeriesDataModule\n",
     "        else:\n",
     "            datamodule_constructor = _DistributedTimeSeriesDataModule\n",
     "        # num_workers_loader / drop_last_loader are presumably set by\n",
     "        # subclasses -- TODO confirm.\n",
     "        datamodule = datamodule_constructor(\n",
     "            dataset=dataset, \n",
     "            batch_size=batch_size,\n",
     "            valid_batch_size=valid_batch_size,\n",
     "            num_workers=self.num_workers_loader,\n",
     "            drop_last=self.drop_last_loader,\n",
     "            shuffle_train=shuffle_train,\n",
     "        )\n",
     "\n",
     "        if self.val_check_steps > self.max_steps:\n",
     "            warnings.warn(\n",
     "                'val_check_steps is greater than max_steps, '\n",
     "                'setting val_check_steps to max_steps.'\n",
     "            )\n",
     "        # Validate on a step interval instead of per-epoch.\n",
     "        val_check_interval = min(self.val_check_steps, self.max_steps)\n",
     "        self.trainer_kwargs['val_check_interval'] = int(val_check_interval)\n",
     "        self.trainer_kwargs['check_val_every_n_epoch'] = None\n",
     "\n",
     "        if is_local:\n",
     "            model = self\n",
     "            trainer = pl.Trainer(**model.trainer_kwargs)\n",
     "            trainer.fit(model, datamodule=datamodule)\n",
     "            model.metrics = trainer.callback_metrics\n",
     "            # drop the trainer reference so the model stays picklable\n",
     "            model.__dict__.pop('_trainer', None)\n",
     "        else:\n",
     "            model = self._fit_distributed(\n",
     "                distributed_config,\n",
     "                datamodule,\n",
     "                val_size,\n",
     "                test_size,\n",
     "            )\n",
     "        return model\n",
     "\n",
     "    def on_fit_start(self):\n",
     "        \"\"\"Seed torch, numpy and `random` so each fit call is reproducible.\"\"\"\n",
     "        torch.manual_seed(self.random_seed)\n",
     "        np.random.seed(self.random_seed)\n",
     "        random.seed(self.random_seed)\n",
     "\n",
     "    def configure_optimizers(self):\n",
     "        \"\"\"Build the optimizer and lr scheduler: the user-provided classes,\n",
     "        or Adam + StepLR(gamma=0.5) as defaults.\"\"\"\n",
     "        if self.optimizer:\n",
     "            optimizer_signature = inspect.signature(self.optimizer)\n",
     "            optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
     "            if 'lr' in optimizer_signature.parameters:\n",
     "                if 'lr' in optimizer_kwargs:\n",
     "                    warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
     "                optimizer_kwargs['lr'] = self.learning_rate\n",
     "            optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
     "        else:\n",
     "            if self.optimizer_kwargs:\n",
     "                warnings.warn(\n",
     "                    \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
     "                )            \n",
     "            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
     "        \n",
     "        # Schedulers step every training step, not every epoch.\n",
     "        lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
     "        if self.lr_scheduler:\n",
     "            lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
     "            lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
     "            if 'optimizer' in lr_scheduler_signature.parameters:\n",
     "                if 'optimizer' in lr_scheduler_kwargs:\n",
     "                    warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
     "                    del lr_scheduler_kwargs['optimizer']\n",
     "            lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
     "        else:\n",
     "            if self.lr_scheduler_kwargs:\n",
     "                warnings.warn(\n",
     "                    \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
     "                )            \n",
     "            lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
     "                optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
     "            )\n",
     "        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
     "\n",
     "    def get_test_size(self):\n",
     "        \"\"\"Return the currently configured test-window size.\"\"\"\n",
     "        return self.test_size\n",
     "\n",
     "    def set_test_size(self, test_size):\n",
     "        \"\"\"Override the test-window size used at prediction time.\"\"\"\n",
     "        self.test_size = test_size\n",
     "\n",
     "    def on_validation_epoch_end(self):\n",
     "        \"\"\"Average recorded validation losses into the `ptl/val_loss` metric.\n",
     "\n",
     "        `validation_step_outputs` is presumably populated by the subclass's\n",
     "        `validation_step` -- TODO confirm.\"\"\"\n",
     "        if self.val_size == 0:\n",
     "            return\n",
     "        losses = torch.stack(self.validation_step_outputs)\n",
     "        avg_loss = losses.mean().item()\n",
     "        self.log(\n",
     "            \"ptl/val_loss\",\n",
     "            avg_loss,\n",
     "            batch_size=losses.size(0),\n",
     "            sync_dist=True,\n",
     "        )\n",
     "        self.valid_trajectories.append((self.global_step, avg_loss))\n",
     "        self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch)\n",
     "\n",
     "    def save(self, path):\n",
     "        \"\"\"Persist hyperparameters and weights to `path` (any fsspec URL).\"\"\"\n",
     "        with fsspec.open(path, 'wb') as f:\n",
     "            torch.save(\n",
     "                {'hyper_parameters': self.hparams, 'state_dict': self.state_dict()},\n",
     "                f,\n",
     "            )\n",
     "\n",
     "    @classmethod\n",
     "    def load(cls, path, **kwargs):\n",
     "        \"\"\"Restore a model written by `save`; extra kwargs go to `torch.load`.\n",
     "\n",
     "        NOTE(review): `torch.load` unpickles arbitrary objects -- only load\n",
     "        trusted checkpoints.\"\"\"\n",
     "        with fsspec.open(path, 'rb') as f:\n",
     "            content = torch.load(f, **kwargs)\n",
     "        # Skip weight init; the state dict below overwrites everything anyway.\n",
     "        with _disable_torch_init():\n",
     "            model = cls(**content['hyper_parameters']) \n",
     "        if \"assign\" in inspect.signature(model.load_state_dict).parameters:\n",
     "            model.load_state_dict(content[\"state_dict\"], strict=True, assign=True)\n",
     "        else:  # pytorch<2.1\n",
     "            model.load_state_dict(content[\"state_dict\"], strict=True)\n",
     "        return model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
